# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import os
import sys
import tempfile
import unittest

# TODO: No idea why pytype cannot find names from this module.
# pytype: disable=name-error
import iree.compiler.tf

if not iree.compiler.tf.is_available():
  # The TensorFlow frontend is optional; exit with status 0 instead of failing
  # the run when it is missing. Only the first literal needs to be an f-string.
  print(f"Skipping test {__file__} because the IREE TensorFlow compiler "
        "is not installed")
  sys.exit(0)

import tensorflow as tf


class SimpleArithmeticModule(tf.Module):
  """Tiny tf.Module with two fixed-signature functions used as compiler input.

  Every argument has a fully static float32 TensorSpec, so no shape
  polymorphism is involved when the module is exported or compiled.
  """

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[4], dtype=tf.float32),
      tf.TensorSpec(shape=[4], dtype=tf.float32),
  ])
  def simple_mul(self, a, b):
    """Returns the elementwise product of two length-4 float32 vectors."""
    return a * b

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[128, 3072], dtype=tf.float32),
      tf.TensorSpec(shape=[3072, 256], dtype=tf.float32),
  ])
  def simple_matmul(self, a, b):
    """Returns the [128, 256] product of a [128, 3072] and [3072, 256] matrix."""
    return tf.linalg.matmul(a, b)


# TODO(laurenzo): More test cases needed (may need additional files).
# Specifically, figure out how to test v1 models.
class TfCompilerTest(tf.test.TestCase):
  """Exercises the iree.compiler.tf import and compile entry points.

  One SimpleArithmeticModule instance and its SavedModel export (written to a
  per-class temporary directory) are created in setUpClass and shared by all
  test methods.
  """

  def testImportSavedModel(self):
    import_mlir = iree.compiler.tf.compile_saved_model(
        self.smdir, import_only=True, output_generic_mlir=True).decode("utf-8")
    # Generic MLIR prints symbol names as a `sym_name` attribute.
    self.assertIn('sym_name = "simple_matmul"', import_mlir)

  def testCompileSavedModel(self):
    binary = iree.compiler.tf.compile_saved_model(
        self.smdir, target_backends=iree.compiler.tf.DEFAULT_TESTING_BACKENDS)
    logging.info("Compiled len: %d", len(binary))
    self.assertIn(b"simple_matmul", binary)
    self.assertIn(b"simple_mul", binary)

  def testCompileModule(self):
    binary = iree.compiler.tf.compile_module(
        self.m, target_backends=iree.compiler.tf.DEFAULT_TESTING_BACKENDS)
    logging.info("Compiled len: %d", len(binary))
    self.assertIn(b"simple_matmul", binary)
    self.assertIn(b"simple_mul", binary)

  @classmethod
  def setUpClass(cls):
    """Builds the module and saves it as a SavedModel under a temp dir."""
    # Let the base test-case classes run their own class-level setup.
    super().setUpClass()
    cls.m = SimpleArithmeticModule()
    cls.tempdir = tempfile.TemporaryDirectory()
    try:
      cls.smdir = os.path.join(cls.tempdir.name, "arith.sm")
      tf.saved_model.save(
          cls.m,
          cls.smdir,
          options=tf.saved_model.SaveOptions(save_debug_info=True))
    except BaseException:
      # unittest does not call tearDownClass when setUpClass raises, so the
      # temporary directory must be released here to avoid leaking it.
      cls.tempdir.cleanup()
      raise

  @classmethod
  def tearDownClass(cls):
    """Removes the temporary SavedModel directory."""
    try:
      cls.tempdir.cleanup()
    finally:
      # Always give the base classes a chance to run their class teardown.
      super().tearDownClass()


if __name__ == "__main__":
  # Set root logging to DEBUG before handing control to the TF test runner.
  # The string level name resolves to the same value as logging.DEBUG.
  logging.basicConfig(level="DEBUG")
  tf.test.main()
